dm_haiku==0.0.6
jax==0.3.23
jaxlib==0.3.22
matplotlib==3.5.1
numpy==1.23.3
optax==0.1.3
scikit_learn==1.1.2
scipy==1.10.1
tensorflow_probability==0.19.0
tqdm==4.64.1
